# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------


"""
Merge the per-activation summary files dumped for convergence debugging into per-step summary files.

Both PyTorch and ORT run result directories are needed. See the `StatisticsSubscriber` usage in
[ORTModule_Convergence_Notes](docs/ORTModule_Convergence_Notes.md) for how to export the results.

Since ORT and PyTorch run the compute in different orders (e.g. their topological sort
implementations differ), the generated per-step summaries must be made comparable between the
ORT and PyTorch runs. So during the merge, the same topological order is used for both.

Example:
    python merge_activation_summary.py --pt_dir pt_out --ort_dir ort_out --output_dir /tmp/output

"""

import argparse
import logging
import os
import shutil
from pathlib import Path

# Module-level logger named after this module; used below for progress and diagnostic messages.
logger = logging.getLogger(__name__)


def generate_summaries_per_step(args):
    """Merge per-activation dump files into one summary file per step directory.

    Every sub-directory ``<pt_dir>/<step>`` is merged into ``<output_dir>/merge_pt/<step>_.txt`` and
    every sub-directory ``<ort_dir>/<step>`` into ``<output_dir>/merge_ort/<step>_.txt``. Both merges
    concatenate the tensor files in the order listed by ``<pt_dir>/step_0/order.txt`` so that the
    resulting PyTorch and ORT summaries can be compared side by side.

    Args:
        args: Parsed arguments providing the ``pt_dir``, ``ort_dir`` and ``output_dir`` attributes.

    Raises:
        FileExistsError: If ``output_dir`` already exists.
        FileNotFoundError: If ``<pt_dir>/step_0/order.txt`` or one of the dump roots does not exist.
    """
    pt_dir = args.pt_dir
    ort_dir = args.ort_dir
    output_path = Path(args.output_dir)

    output_path.mkdir(parents=True, exist_ok=False)

    # We should use the order.txt generated by the PyTorch run, which means we follow the PyTorch
    # topological order to compare activation results. Here we assume order.txt can be found at
    # pt_dir/step_0/order.txt.
    topo_order_file_path = Path(f"{pt_dir}/step_0/order.txt")

    src_ort_path = Path(ort_dir)
    src_pt_path = Path(pt_dir)

    merge_ort_path = output_path / "merge_ort"
    merge_pt_path = output_path / "merge_pt"

    def generate_summary_per_step(topo_order_file_path: Path, dump_src_path: Path, merge_dest_path: Path):
        """Write one merged ``<step>_.txt`` file into merge_dest_path per step directory of dump_src_path.

        Tensor files are appended in the order given by topo_order_file_path, each followed by a
        newline. Tensors that are listed but have no file in a step directory are skipped with a warning.
        """
        logger.warning(
            "Start generating summary per step for [%s] following typological order in [%s]",
            dump_src_path.as_posix(),
            topo_order_file_path.as_posix(),
        )
        with topo_order_file_path.open(mode="r", encoding="utf-8") as order_file:
            # Drop blank entries (e.g. a trailing empty line): an empty name would resolve to the
            # step directory itself and make the later open() fail with IsADirectoryError.
            tensor_name_in_order = [
                name for name in (line.rstrip("\n") for line in order_file) if name.strip()
            ]

        # Always start from an empty destination so stale merged files never survive.
        if merge_dest_path.exists():
            shutil.rmtree(merge_dest_path.as_posix())
        merge_dest_path.mkdir(parents=True, exist_ok=False)

        for dump_step_path in dump_src_path.iterdir():
            # Only step directories hold tensor dumps; ignore stray files in the dump root.
            if not dump_step_path.is_dir():
                continue

            merge_filename_for_sub_dir = merge_dest_path / f"{dump_step_path.name}_.txt"
            with merge_filename_for_sub_dir.open(mode="w", encoding="utf-8") as outfile:
                for filename in tensor_name_in_order:
                    full_filename = dump_step_path / filename
                    # is_file() (rather than exists()) also rejects names that resolve to a directory.
                    if not full_filename.is_file():
                        # Be noted that some tensor handled in PyTorch might be missing in ORT graph
                        # (if the activation is not used by others, which is pruned during export)
                        logger.warning("tensor %s not exist", full_filename)
                        continue

                    with full_filename.open(mode="r", encoding="utf-8") as infile:
                        outfile.write(infile.read())

                    outfile.write("\n")

        logger.warning(
            "Finish generating summary per step for [%s] following typological order in [%s], merged files are in [%s]",
            dump_src_path.as_posix(),
            topo_order_file_path.as_posix(),
            merge_dest_path.as_posix(),
        )

    generate_summary_per_step(topo_order_file_path, src_pt_path, merge_pt_path)
    generate_summary_per_step(topo_order_file_path, src_ort_path, merge_ort_path)


def parse_arguments(argv=None):
    """Parse the command-line arguments of the merge script.

    The script merges per-step summaries from dumped files. Data are collected with
    `StatisticsSubscriber` for ORT and PyTorch runs.
    Required parameters include
    > dump root folders for ORT and PyTorch.
    > output folder.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the default),
            argparse reads the arguments from ``sys.argv``, so existing callers are unaffected.

    Returns:
        Namespace: parsed arguments with the attributes ``pt_dir``, ``ort_dir``,
        ``output_dir`` (all ``str``) and ``overwrite`` (``bool``, ``False`` unless the flag is given).
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--pt_dir",
        required=True,
        type=str,
        help="Root of input directory of PyTorch run result of activation dump.",
    )

    parser.add_argument(
        "--ort_dir",
        required=True,
        type=str,
        help="Root of input directory of ORTModule run result of activation dump.",
    )

    parser.add_argument(
        "--output_dir",
        required=True,
        type=str,
        help="Root of output directory for generated PyTorch/ORTModule per-step summaries.",
    )

    parser.add_argument(
        "--overwrite",
        required=False,
        action="store_true",
        help="Overwrite exists output files.",
    )

    # Passing None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)


def main():
    """Script entry point: parse the arguments, prepare the output directory and run the merge.

    An already existing output directory is removed when ``--overwrite`` is given.

    Raises:
        FileExistsError: If the output directory already exists and ``--overwrite`` is not set.
    """
    args = parse_arguments()
    output_dir = args.output_dir

    if os.path.exists(output_dir):
        # Refuse to touch existing results unless the user explicitly opted in.
        if not args.overwrite:
            raise FileExistsError(
                f"Output directory {args.output_dir} already exists. Enable --overwrite to allow overwriting it."
            )
        logger.warning("Output directory %s already exists, overwriting it.", output_dir)
        shutil.rmtree(output_dir)

    logger.info("Arguments: %s", str(args))
    generate_summaries_per_step(args)


# Run the merge only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
